application: targeted data collection

knowing what we know, where and when should we plan to next collect data?

planning the next test

survival analysis

Code
# Load the CmdStan interface for R.
library(cmdstanr)

# Compile the censored log-logistic survival model from its Stan source.
survival_model <- cmdstan_model(stan_file = "survival.stan")
# Print the canonically formatted Stan program (output shown below).
survival_model$format()
data {
  int<lower=0> n_meas; // number of observations
  vector<lower=0>[n_meas] obs_time; // time of observation
  vector<lower=0>[n_meas] fail_lb; // lower bound of failure time
  vector<lower=0>[n_meas] fail_ub; // upper bound of failure time
  
  array[n_meas] int<lower=0, upper=1> fail_status; // 1 if a failure has occurred (interval-censored data), 0 if right-censored
  
  int<lower=0> n_pred; // number of predictions
  vector<lower=0>[n_pred] pred_time; // time of prediction
}
parameters {
  real<lower=0> scale; // scale parameter of the log-logistic distribution
  real<lower=0> shape; // shape parameter of the log-logistic distribution
}
model {
  //priors
  scale ~ normal(8, 3);
  shape ~ normal(6, 3);
  
  //likelihood
  for (n in 1 : n_meas) {
    if (fail_status[n] == 0) {
      // right-censored: Pr(T > obs_time) = 1 - F(obs_time)
      target += log1m(loglogistic_cdf(obs_time[n] | scale, shape));
    } else {
      // interval-censored: Pr(fail_lb < T < fail_ub) = F(fail_ub) - F(fail_lb)
      target += log(loglogistic_cdf(fail_ub[n] | scale, shape)
                    - loglogistic_cdf(fail_lb[n] | scale, shape));
    }
  }
}
generated quantities {
  vector[n_meas] log_lik; // pointwise log-likelihood (e.g. for loo/waic)
  vector[n_pred] p_fail_pred; // posterior predictive Pr(failure by pred_time)
  
  for (n in 1 : n_meas) {
    // FIX: this loop previously tested fail_status[n] == 1 for the
    // right-censored (log1m) term, swapping the two branches relative to the
    // model block; log_lik must mirror the likelihood contribution above.
    if (fail_status[n] == 0) {
      log_lik[n] = log1m(loglogistic_cdf(obs_time[n] | scale, shape));
    } else {
      log_lik[n] = log(loglogistic_cdf(fail_ub[n] | scale, shape)
                       - loglogistic_cdf(fail_lb[n] | scale, shape));
    }
  }
  
  for (n in 1 : n_pred) {
    p_fail_pred[n] = loglogistic_cdf(pred_time[n] | scale, shape);
  }
}
Code
# Python equivalent of the R cell above: compile the same Stan program.
import cmdstanpy

# Reuses the existing executable when it is newer than the .stan source
# (see the "not recompiling" log line below).
survival_model = cmdstanpy.CmdStanModel(stan_file = "survival.stan")
INFO:cmdstanpy:found newer exe file, not recompiling
Code
# Retrieve and display the Stan program text attached to the compiled model.
from pygments import highlight
from pygments.formatters import NullFormatter
from pygments.lexers import StanLexer

stan_code = survival_model.code()

# NullFormatter tokenizes without adding any markup, so this is effectively
# a syntax-checked passthrough of the source text.
lexer = StanLexer()
formatter = NullFormatter()
formatted_stan_code = highlight(stan_code, lexer, formatter)

print(formatted_stan_code)
data {
  int <lower = 0> n_meas;                   // number of observations
  vector <lower = 0> [n_meas] obs_time;     // time of observation
  vector <lower = 0> [n_meas] fail_lb;      // lower bound of failure time
  vector <lower = 0> [n_meas] fail_ub;      // upper bound of failure time

  array [n_meas] int<lower = 0, upper = 1> fail_status; // 1 if a failure has occurred (interval-censored data), 0 if right-censored

  int <lower = 0> n_pred;                   // number of predictions
  vector <lower = 0> [n_pred] pred_time;    // time of prediction
}

parameters {
  real <lower = 0> scale; // scale parameter
  real <lower = 0> shape; // shape parameter
}

model{
    //priors
    scale ~ normal(8, 3);
    shape ~ normal(6, 3);

    //likelihood
    for(n in 1:n_meas){
        if(fail_status[n] == 0){
            // right-censored: Pr(T > obs_time)
            target += log1m(loglogistic_cdf(obs_time[n] | scale, shape));
        } else {
            // interval-censored: Pr(fail_lb < T < fail_ub)
            target += log(
                          loglogistic_cdf(fail_ub[n] | scale, shape) - 
                          loglogistic_cdf(fail_lb[n] | scale, shape)
                        );
        }
    }
}

generated quantities {
  vector [n_meas] log_lik;
  vector [n_pred] p_fail_pred;

  // FIX: this loop previously tested fail_status[n] == 1 for the
  // right-censored (log1m) term, swapping the branches relative to the
  // model block above; log_lik must use the same per-observation term.
  for(n in 1:n_meas){
    if(fail_status[n] == 0){
      log_lik[n] = log1m(loglogistic_cdf(obs_time[n] | scale, shape));
    } else {
      log_lik[n] = log(
                        loglogistic_cdf(fail_ub[n] | scale, shape) - 
                        loglogistic_cdf(fail_lb[n] | scale, shape)
                      );
    }
  }

  for(n in 1:n_pred){
    p_fail_pred[n] = loglogistic_cdf(pred_time[n] | scale, shape);
  }

}
Code
# Julia equivalent: Turing.jl for MCMC; log1mexp for stable log(1 - exp(x)).
using Turing, Random
using LogExpFunctions: log1mexp

# Project-local log-logistic distribution; the model below calls cdf and
# survival on it, so this file is expected to provide both.
include("../../data/LogLogisticDistribution.jl")
LogLogisticDistribution (generic function with 1 method)
Code

@model function loglogistic_survival(
    obs_time::Vector{Float64},     # time of observation
    fail_lb::Vector{Float64},      # lower bound of failure time
    fail_ub::Vector{Float64},      # upper bound of failure time
    fail_status::Vector{Int}   # 0 if right-censored, 1 if interval-censored
)
    # Priors: positive-truncated normals on both log-logistic parameters
    scale ~ truncated(Normal(8, 3), lower = 0)
    shape ~ truncated(Normal(6, 3), lower = 0)

    # Log-logistic failure-time distribution under the current draw
    dist = LogLogisticDistribution(scale, shape)

    # Likelihood: each unit contributes a censored-observation term
    for i in eachindex(obs_time)
        if fail_status[i] == 0
            # Right censored: P(T > obs_time)
            Turing.@addlogprob! log(survival(dist, obs_time[i]))
        else
            # Interval censored: P(lb < T < ub)
            Turing.@addlogprob! log(cdf(dist, fail_ub[i]) - cdf(dist, fail_lb[i]))
        end
    end
end
loglogistic_survival (generic function with 2 methods)

survival analysis

Code
library(tidyverse)

# Censored failure observations; a finite fail_ub marks an observed
# (interval-censored) failure, Inf marks a right-censored unit.
failure_data <- read_csv("../../data/failures.csv")

model_data <- list(
  n_meas = nrow(failure_data),
  # every component was inspected at time 12 -- TODO confirm against the data notes
  obs_time = rep(12, nrow(failure_data)),
  fail_lb = failure_data$fail_lb,
  fail_ub = failure_data$fail_ub,
  # 1 if a failure was observed (finite upper bound), 0 if right-censored
  fail_status = is.finite(failure_data$fail_ub) |> as.integer(),
  # prediction grid: 101 evenly spaced times on [0, 20]
  n_pred = 101,
  pred_time = seq(from = 0, to = 20, length.out = 101)
)

# Fit with 4 chains in parallel; fixed seed for reproducibility.
survival_fit <- survival_model$sample(
  data = model_data,
  chains = 4,
  parallel_chains = parallel::detectCores(),
  seed = 231123,
  iter_warmup = 2000,
  iter_sampling = 2000
)
Running MCMC with 4 chains, at most 16 in parallel...

Chain 1 Iteration:    1 / 4000 [  0%]  (Warmup) 
Chain 1 Iteration:  100 / 4000 [  2%]  (Warmup) 
Chain 1 Iteration:  200 / 4000 [  5%]  (Warmup) 
Chain 1 Iteration:  300 / 4000 [  7%]  (Warmup) 
Chain 1 Iteration:  400 / 4000 [ 10%]  (Warmup) 
Chain 1 Iteration:  500 / 4000 [ 12%]  (Warmup) 
Chain 1 Iteration:  600 / 4000 [ 15%]  (Warmup) 
Chain 1 Iteration:  700 / 4000 [ 17%]  (Warmup) 
Chain 1 Iteration:  800 / 4000 [ 20%]  (Warmup) 
Chain 1 Iteration:  900 / 4000 [ 22%]  (Warmup) 
Chain 1 Iteration: 1000 / 4000 [ 25%]  (Warmup) 
Chain 1 Iteration: 1100 / 4000 [ 27%]  (Warmup) 
Chain 1 Iteration: 1200 / 4000 [ 30%]  (Warmup) 
Chain 1 Iteration: 1300 / 4000 [ 32%]  (Warmup) 
Chain 1 Iteration: 1400 / 4000 [ 35%]  (Warmup) 
Chain 1 Iteration: 1500 / 4000 [ 37%]  (Warmup) 
Chain 1 Iteration: 1600 / 4000 [ 40%]  (Warmup) 
Chain 1 Iteration: 1700 / 4000 [ 42%]  (Warmup) 
Chain 1 Iteration: 1800 / 4000 [ 45%]  (Warmup) 
Chain 1 Iteration: 1900 / 4000 [ 47%]  (Warmup) 
Chain 1 Iteration: 2000 / 4000 [ 50%]  (Warmup) 
Chain 1 Iteration: 2001 / 4000 [ 50%]  (Sampling) 
Chain 1 Iteration: 2100 / 4000 [ 52%]  (Sampling) 
Chain 1 Iteration: 2200 / 4000 [ 55%]  (Sampling) 
Chain 1 Iteration: 2300 / 4000 [ 57%]  (Sampling) 
Chain 1 Iteration: 2400 / 4000 [ 60%]  (Sampling) 
Chain 1 Iteration: 2500 / 4000 [ 62%]  (Sampling) 
Chain 2 Iteration:    1 / 4000 [  0%]  (Warmup) 
Chain 2 Iteration:  100 / 4000 [  2%]  (Warmup) 
Chain 2 Iteration:  200 / 4000 [  5%]  (Warmup) 
Chain 2 Iteration:  300 / 4000 [  7%]  (Warmup) 
Chain 2 Iteration:  400 / 4000 [ 10%]  (Warmup) 
Chain 2 Iteration:  500 / 4000 [ 12%]  (Warmup) 
Chain 2 Iteration:  600 / 4000 [ 15%]  (Warmup) 
Chain 2 Iteration:  700 / 4000 [ 17%]  (Warmup) 
Chain 2 Iteration:  800 / 4000 [ 20%]  (Warmup) 
Chain 2 Iteration:  900 / 4000 [ 22%]  (Warmup) 
Chain 2 Iteration: 1000 / 4000 [ 25%]  (Warmup) 
Chain 2 Iteration: 1100 / 4000 [ 27%]  (Warmup) 
Chain 2 Iteration: 1200 / 4000 [ 30%]  (Warmup) 
Chain 2 Iteration: 1300 / 4000 [ 32%]  (Warmup) 
Chain 2 Iteration: 1400 / 4000 [ 35%]  (Warmup) 
Chain 2 Iteration: 1500 / 4000 [ 37%]  (Warmup) 
Chain 2 Iteration: 1600 / 4000 [ 40%]  (Warmup) 
Chain 2 Iteration: 1700 / 4000 [ 42%]  (Warmup) 
Chain 2 Iteration: 1800 / 4000 [ 45%]  (Warmup) 
Chain 2 Iteration: 1900 / 4000 [ 47%]  (Warmup) 
Chain 2 Iteration: 2000 / 4000 [ 50%]  (Warmup) 
Chain 2 Iteration: 2001 / 4000 [ 50%]  (Sampling) 
Chain 2 Iteration: 2100 / 4000 [ 52%]  (Sampling) 
Chain 2 Iteration: 2200 / 4000 [ 55%]  (Sampling) 
Chain 2 Iteration: 2300 / 4000 [ 57%]  (Sampling) 
Chain 2 Iteration: 2400 / 4000 [ 60%]  (Sampling) 
Chain 2 Iteration: 2500 / 4000 [ 62%]  (Sampling) 
Chain 2 Iteration: 2600 / 4000 [ 65%]  (Sampling) 
Chain 2 Iteration: 2700 / 4000 [ 67%]  (Sampling) 
Chain 2 Iteration: 2800 / 4000 [ 70%]  (Sampling) 
Chain 2 Iteration: 2900 / 4000 [ 72%]  (Sampling) 
Chain 2 Iteration: 3000 / 4000 [ 75%]  (Sampling) 
Chain 2 Iteration: 3100 / 4000 [ 77%]  (Sampling) 
Chain 3 Iteration:    1 / 4000 [  0%]  (Warmup) 
Chain 3 Iteration:  100 / 4000 [  2%]  (Warmup) 
Chain 3 Iteration:  200 / 4000 [  5%]  (Warmup) 
Chain 3 Iteration:  300 / 4000 [  7%]  (Warmup) 
Chain 3 Iteration:  400 / 4000 [ 10%]  (Warmup) 
Chain 3 Iteration:  500 / 4000 [ 12%]  (Warmup) 
Chain 3 Iteration:  600 / 4000 [ 15%]  (Warmup) 
Chain 3 Iteration:  700 / 4000 [ 17%]  (Warmup) 
Chain 3 Iteration:  800 / 4000 [ 20%]  (Warmup) 
Chain 3 Iteration:  900 / 4000 [ 22%]  (Warmup) 
Chain 3 Iteration: 1000 / 4000 [ 25%]  (Warmup) 
Chain 3 Iteration: 1100 / 4000 [ 27%]  (Warmup) 
Chain 3 Iteration: 1200 / 4000 [ 30%]  (Warmup) 
Chain 3 Iteration: 1300 / 4000 [ 32%]  (Warmup) 
Chain 3 Iteration: 1400 / 4000 [ 35%]  (Warmup) 
Chain 3 Iteration: 1500 / 4000 [ 37%]  (Warmup) 
Chain 3 Iteration: 1600 / 4000 [ 40%]  (Warmup) 
Chain 3 Iteration: 1700 / 4000 [ 42%]  (Warmup) 
Chain 3 Iteration: 1800 / 4000 [ 45%]  (Warmup) 
Chain 3 Iteration: 1900 / 4000 [ 47%]  (Warmup) 
Chain 3 Iteration: 2000 / 4000 [ 50%]  (Warmup) 
Chain 3 Iteration: 2001 / 4000 [ 50%]  (Sampling) 
Chain 3 Iteration: 2100 / 4000 [ 52%]  (Sampling) 
Chain 3 Iteration: 2200 / 4000 [ 55%]  (Sampling) 
Chain 3 Iteration: 2300 / 4000 [ 57%]  (Sampling) 
Chain 3 Iteration: 2400 / 4000 [ 60%]  (Sampling) 
Chain 3 Iteration: 2500 / 4000 [ 62%]  (Sampling) 
Chain 3 Iteration: 2600 / 4000 [ 65%]  (Sampling) 
Chain 3 Iteration: 2700 / 4000 [ 67%]  (Sampling) 
Chain 3 Iteration: 2800 / 4000 [ 70%]  (Sampling) 
Chain 3 Iteration: 2900 / 4000 [ 72%]  (Sampling) 
Chain 3 Iteration: 3000 / 4000 [ 75%]  (Sampling) 
Chain 3 Iteration: 3100 / 4000 [ 77%]  (Sampling) 
Chain 3 Iteration: 3200 / 4000 [ 80%]  (Sampling) 
Chain 4 Iteration:    1 / 4000 [  0%]  (Warmup) 
Chain 4 Iteration:  100 / 4000 [  2%]  (Warmup) 
Chain 4 Iteration:  200 / 4000 [  5%]  (Warmup) 
Chain 4 Iteration:  300 / 4000 [  7%]  (Warmup) 
Chain 4 Iteration:  400 / 4000 [ 10%]  (Warmup) 
Chain 4 Iteration:  500 / 4000 [ 12%]  (Warmup) 
Chain 4 Iteration:  600 / 4000 [ 15%]  (Warmup) 
Chain 4 Iteration:  700 / 4000 [ 17%]  (Warmup) 
Chain 4 Iteration:  800 / 4000 [ 20%]  (Warmup) 
Chain 4 Iteration:  900 / 4000 [ 22%]  (Warmup) 
Chain 4 Iteration: 1000 / 4000 [ 25%]  (Warmup) 
Chain 4 Iteration: 1100 / 4000 [ 27%]  (Warmup) 
Chain 4 Iteration: 1200 / 4000 [ 30%]  (Warmup) 
Chain 4 Iteration: 1300 / 4000 [ 32%]  (Warmup) 
Chain 4 Iteration: 1400 / 4000 [ 35%]  (Warmup) 
Chain 4 Iteration: 1500 / 4000 [ 37%]  (Warmup) 
Chain 4 Iteration: 1600 / 4000 [ 40%]  (Warmup) 
Chain 4 Iteration: 1700 / 4000 [ 42%]  (Warmup) 
Chain 4 Iteration: 1800 / 4000 [ 45%]  (Warmup) 
Chain 4 Iteration: 1900 / 4000 [ 47%]  (Warmup) 
Chain 4 Iteration: 2000 / 4000 [ 50%]  (Warmup) 
Chain 4 Iteration: 2001 / 4000 [ 50%]  (Sampling) 
Chain 4 Iteration: 2100 / 4000 [ 52%]  (Sampling) 
Chain 4 Iteration: 2200 / 4000 [ 55%]  (Sampling) 
Chain 4 Iteration: 2300 / 4000 [ 57%]  (Sampling) 
Chain 4 Iteration: 2400 / 4000 [ 60%]  (Sampling) 
Chain 4 Iteration: 2500 / 4000 [ 62%]  (Sampling) 
Chain 4 Iteration: 2600 / 4000 [ 65%]  (Sampling) 
Chain 4 Iteration: 2700 / 4000 [ 67%]  (Sampling) 
Chain 4 Iteration: 2800 / 4000 [ 70%]  (Sampling) 
Chain 4 Iteration: 2900 / 4000 [ 72%]  (Sampling) 
Chain 4 Iteration: 3000 / 4000 [ 75%]  (Sampling) 
Chain 4 Iteration: 3100 / 4000 [ 77%]  (Sampling) 
Chain 4 Iteration: 3200 / 4000 [ 80%]  (Sampling) 
Chain 4 Iteration: 3300 / 4000 [ 82%]  (Sampling) 
Chain 1 Iteration: 2600 / 4000 [ 65%]  (Sampling) 
Chain 1 Iteration: 2700 / 4000 [ 67%]  (Sampling) 
Chain 1 Iteration: 2800 / 4000 [ 70%]  (Sampling) 
Chain 1 Iteration: 2900 / 4000 [ 72%]  (Sampling) 
Chain 1 Iteration: 3000 / 4000 [ 75%]  (Sampling) 
Chain 1 Iteration: 3100 / 4000 [ 77%]  (Sampling) 
Chain 1 Iteration: 3200 / 4000 [ 80%]  (Sampling) 
Chain 1 Iteration: 3300 / 4000 [ 82%]  (Sampling) 
Chain 1 Iteration: 3400 / 4000 [ 85%]  (Sampling) 
Chain 1 Iteration: 3500 / 4000 [ 87%]  (Sampling) 
Chain 1 Iteration: 3600 / 4000 [ 90%]  (Sampling) 
Chain 1 Iteration: 3700 / 4000 [ 92%]  (Sampling) 
Chain 1 Iteration: 3800 / 4000 [ 95%]  (Sampling) 
Chain 1 Iteration: 3900 / 4000 [ 97%]  (Sampling) 
Chain 1 Iteration: 4000 / 4000 [100%]  (Sampling) 
Chain 2 Iteration: 3200 / 4000 [ 80%]  (Sampling) 
Chain 2 Iteration: 3300 / 4000 [ 82%]  (Sampling) 
Chain 2 Iteration: 3400 / 4000 [ 85%]  (Sampling) 
Chain 2 Iteration: 3500 / 4000 [ 87%]  (Sampling) 
Chain 2 Iteration: 3600 / 4000 [ 90%]  (Sampling) 
Chain 2 Iteration: 3700 / 4000 [ 92%]  (Sampling) 
Chain 2 Iteration: 3800 / 4000 [ 95%]  (Sampling) 
Chain 2 Iteration: 3900 / 4000 [ 97%]  (Sampling) 
Chain 2 Iteration: 4000 / 4000 [100%]  (Sampling) 
Chain 3 Iteration: 3300 / 4000 [ 82%]  (Sampling) 
Chain 3 Iteration: 3400 / 4000 [ 85%]  (Sampling) 
Chain 3 Iteration: 3500 / 4000 [ 87%]  (Sampling) 
Chain 3 Iteration: 3600 / 4000 [ 90%]  (Sampling) 
Chain 3 Iteration: 3700 / 4000 [ 92%]  (Sampling) 
Chain 3 Iteration: 3800 / 4000 [ 95%]  (Sampling) 
Chain 3 Iteration: 3900 / 4000 [ 97%]  (Sampling) 
Chain 3 Iteration: 4000 / 4000 [100%]  (Sampling) 
Chain 4 Iteration: 3400 / 4000 [ 85%]  (Sampling) 
Chain 4 Iteration: 3500 / 4000 [ 87%]  (Sampling) 
Chain 4 Iteration: 3600 / 4000 [ 90%]  (Sampling) 
Chain 4 Iteration: 3700 / 4000 [ 92%]  (Sampling) 
Chain 4 Iteration: 3800 / 4000 [ 95%]  (Sampling) 
Chain 4 Iteration: 3900 / 4000 [ 97%]  (Sampling) 
Chain 4 Iteration: 4000 / 4000 [100%]  (Sampling) 
Chain 1 finished in 0.3 seconds.
Chain 2 finished in 0.3 seconds.
Chain 3 finished in 0.3 seconds.
Chain 4 finished in 0.3 seconds.

All 4 chains finished successfully.
Mean chain execution time: 0.3 seconds.
Total execution time: 0.5 seconds.
Code
survival_fit$summary()
# A tibble: 124 × 10
   variable     mean median    sd   mad     q5    q95  rhat ess_bulk ess_tail
   <chr>       <dbl>  <dbl> <dbl> <dbl>  <dbl>  <dbl> <dbl>    <dbl>    <dbl>
 1 lp__       -33.0  -32.7  1.04  0.747 -35.0  -32.0   1.00    3431.    4734.
 2 scale        9.48   9.45 0.693 0.678   8.38  10.6   1.00    5912.    4789.
 3 shape        5.62   5.57 1.09  1.10    3.93   7.50  1.00    5308.    4987.
 4 log_lik[1]  -1.59  -1.56 0.378 0.369  -2.27  -1.03  1.00    5668.    4719.
 5 log_lik[2]  -1.59  -1.56 0.378 0.369  -2.27  -1.03  1.00    5668.    4719.
 6 log_lik[3]  -1.59  -1.56 0.378 0.369  -2.27  -1.03  1.00    5668.    4719.
 7 log_lik[4]  -1.59  -1.56 0.378 0.369  -2.27  -1.03  1.00    5668.    4719.
 8 log_lik[5]  -1.59  -1.56 0.378 0.369  -2.27  -1.03  1.00    5668.    4719.
 9 log_lik[6]  -1.59  -1.56 0.378 0.369  -2.27  -1.03  1.00    5668.    4719.
10 log_lik[7]  -1.59  -1.56 0.378 0.369  -2.27  -1.03  1.00    5668.    4719.
# ℹ 114 more rows
Code
import polars as pl, numpy as np
import multiprocessing

# Cast both bounds to floats so Inf (right-censoring) is representable.
failure_data = pl.read_csv("../../data/failures.csv").with_columns([
    pl.col("fail_ub").cast(pl.Float64),
    pl.col("fail_lb").cast(pl.Float64)
])

# Finite stand-in for Inf -- presumably because the data file written for
# CmdStan cannot encode Inf; TODO confirm against cmdstanpy's data writer.
large_num = 1e10

fail_ub = failure_data["fail_ub"].to_numpy().copy()
fail_ub[~np.isfinite(fail_ub)] = large_num

model_data = {
    "n_meas": failure_data.shape[0],
    # every component was inspected at time 12 (matches the R cell above)
    "obs_time": [12] * failure_data.shape[0],
    "fail_lb": failure_data["fail_lb"].to_numpy(),
    "fail_ub": fail_ub,
    # 1 if a failure was observed (finite upper bound), 0 if right-censored
    "fail_status": (failure_data["fail_ub"].is_finite().cast(pl.Int64)).to_numpy(),
    # prediction grid: 101 evenly spaced times on [0, 20]
    "n_pred": 101,
    "pred_time": np.linspace(start = 0, stop = 20, num = 101)
}

# NOTE(review): parallel_chains = 1 here, unlike the R fit which uses all cores.
survival_fit = survival_model.sample(
  data = model_data,
  chains = 4,
  parallel_chains = 1,
  seed = 231123,
  iter_warmup = 2000,
  iter_sampling = 2000
)
                                                                                                                                                                                                                                                                                                                                

INFO:cmdstanpy:CmdStan start processing

chain 1 |          | 00:00 Status

chain 2 |          | 00:00 Status


chain 3 |          | 00:00 Status



chain 4 |          | 00:00 Status
chain 1 |#####4    | 00:00 Iteration: 2001 / 4000 [ 50%]  (Sampling)

chain 2 |2         | 00:00 Status

chain 2 |######1   | 00:00 Iteration: 2300 / 4000 [ 57%]  (Sampling)

chain 2 |#########5| 00:00 Iteration: 3700 / 4000 [ 92%]  (Sampling)


chain 3 |2         | 00:00 Status


chain 3 |######1   | 00:00 Iteration: 2300 / 4000 [ 57%]  (Sampling)


chain 3 |#########5| 00:00 Iteration: 3700 / 4000 [ 92%]  (Sampling)



chain 4 |2         | 00:00 Status



chain 4 |######1   | 00:01 Iteration: 2300 / 4000 [ 57%]  (Sampling)



chain 4 |#########5| 00:01 Iteration: 3700 / 4000 [ 92%]  (Sampling)
chain 1 |##########| 00:01 Sampling completed                       

chain 2 |##########| 00:01 Sampling completed                       

chain 3 |##########| 00:01 Sampling completed                       

chain 4 |##########| 00:01 Sampling completed                       
INFO:cmdstanpy:CmdStan done processing.
Code
survival_fit.summary()
                   Mean     MCSE  StdDev     5%  ...   95%   N_Eff  N_Eff/s  R_hat
name                                             ...                              
lp__             -33.00  0.01800   1.000 -35.00  ... -32.0  3500.0   4000.0    1.0
scale              9.50  0.00910   0.690   8.40  ...  11.0  5800.0   6600.0    1.0
shape              5.60  0.01500   1.100   3.90  ...   7.5  5400.0   6100.0    1.0
log_lik[1]        -1.60  0.00500   0.380  -2.30  ...  -1.0  5750.0   6549.0    1.0
log_lik[2]        -1.60  0.00500   0.380  -2.30  ...  -1.0  5750.0   6549.0    1.0
...                 ...      ...     ...    ...  ...   ...     ...      ...    ...
p_fail_pred[97]    0.97  0.00033   0.022   0.93  ...   1.0  4457.0   5076.0    1.0
p_fail_pred[98]    0.98  0.00032   0.021   0.93  ...   1.0  4452.0   5071.0    1.0
p_fail_pred[99]    0.98  0.00031   0.021   0.94  ...   1.0  4448.0   5066.0    1.0
p_fail_pred[100]   0.98  0.00030   0.020   0.94  ...   1.0  4444.0   5062.0    1.0
p_fail_pred[101]   0.98  0.00029   0.019   0.94  ...   1.0  4441.0   5058.0    1.0

[124 rows x 9 columns]
Code
using CSV, DataFrames, DataFramesMeta

# Load the censored failure observations.
failure_data = CSV.read("../../data/failures.csv", DataFrame)

# Every component was inspected at t = 12; a finite upper bound marks an
# observed (interval-censored) failure, Inf marks a right-censored unit.
observation_times = fill(12.0, nrow(failure_data))
status = Int.(isfinite.(failure_data.fail_ub))

model = loglogistic_survival(
    observation_times,
    failure_data.fail_lb,
    failure_data.fail_ub,
    status
)
survival_fit = sample(MersenneTwister(231123), model, NUTS(), MCMCThreads(), 2000, 4)

survival_fit
Code
# echo: false
survival_fit
Chains MCMC chain (2000×14×4 Array{Float64, 3}):

Iterations        = 1001:1:3000
Number of chains  = 4
Samples per chain = 2000
Wall duration     = 14.72 seconds
Compute duration  = 11.58 seconds
parameters        = scale, shape
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters      mean       std      mcse    ess_bulk    ess_tail      rhat   ⋯
      Symbol   Float64   Float64   Float64     Float64     Float64   Float64   ⋯

       scale    9.4688    0.6886    0.0090   5854.0791   4655.8635    1.0006   ⋯
       shape    5.6014    1.0963    0.0143   5854.6561   5483.7428    1.0012   ⋯
                                                                1 column omitted

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64

       scale    8.1597    9.0166    9.4565    9.9064   10.9001
       shape    3.6211    4.8263    5.5506    6.3210    7.8421

expected information gain

Code
# Names of all quantities tracked by the sampler.
params <- survival_fit$metadata()$model_params

# Split into the distribution parameters and the posterior predictions.
dist_params <- params[grep(pattern = "scale|shape", x = params)]
pred_params <- params[grep(pattern = "pred", x = params)]

# Long-format posterior predictive draws, tagged with their prediction time.
# NOTE(review): rep(each = draws_per_chain * n_chains) assumes tidy_mcmc_draws
# orders rows parameter-first -- confirm against DomDF::tidy_mcmc_draws output.
post_pred <- survival_fit |>
  DomDF::tidy_mcmc_draws(params = pred_params) |>
  mutate(time = rep(x = model_data$pred_time, 
                    each = survival_fit$metadata()$iter_sampling * length(survival_fit$metadata()$id)))
Code
import pandas as pd

# Column names of everything in the posterior (parameters + generated quantities).
params = survival_fit.column_names

# The posterior predictive failure probabilities, one column per prediction time.
pred_params = [p for p in params if "p_fail_pred" in p]

draws_df = survival_fit.draws_pd(vars="p_fail_pred")

n_chains = survival_fit.chains
n_draws = survival_fit.num_draws_sampling

# Tag each row with its chain and within-chain iteration; draws_pd stacks
# chains one after another, n_draws rows per chain -- TODO confirm ordering.
draws_df['Chain'] = [chain for chain in range(1, n_chains + 1) for _ in range(n_draws)]
draws_df['Iteration'] = list(range(1, n_draws + 1)) * n_chains

# Wide -> long: one row per (chain, iteration, parameter).
df_long = draws_df.melt(id_vars=['Chain', 'Iteration'],
                        var_name='Parameter',
                        value_name='value')

# Map each p_fail_pred[i] column onto its prediction time.
# NOTE(review): assumes pred_params is in index order 1..n_pred -- confirm.
n_preds = len(pred_params)
mapping_df = pd.DataFrame({
    "Parameter": pred_params,
    "time": model_data["pred_time"][:n_preds]
})

df_long = df_long.merge(mapping_df, on='Parameter', how='left')

# Back to polars for the downstream uncertainty calculations.
post_pred = pl.from_pandas(df_long)
Code
# Prediction grid for the Julia fit.
pred_times = 0:0.2:20

# Posterior predictive failure probabilities: for each posterior draw of
# (scale, shape), evaluate the log-logistic CDF over the whole grid, then
# regroup so each row holds one prediction time plus the vector of draws.
post_pred = survival_fit |> DataFrame |>
    df -> @rselect(df, :iteration, :chain, :scale, :shape) |>
    df -> @rtransform(df, :pr_fail_pred = cdf.(LogLogisticDistribution(:scale, :shape), pred_times)) |>
    df -> df.pr_fail_pred |>
    preds -> [getindex.(preds, i) for i in 1:length(pred_times)] |>
    preds -> DataFrame(
        pred_time = pred_times, 
        pr_fail_pred = preds)

expected information gain

can be computationally intensive

expected information gain

  • quantify uncertainty in posterior predictions
  • identify prospective data collection options
  • generate all possible outcome scenarios
    • here (helpfully): failure or no failure
  • for each outcome:
    • simulate the data collection and re-fit the model
    • quantify uncertainty in the new posterior predictions
    • find the difference (reduction in uncertainty with the new data)
    • weight the reduction by the probability of the outcome
  • compare the expected “information gain” to rank order data collection options

measures of uncertainty

  • entropy?
  • log-likelihood?
  • kernel density estimation?
  • variance?
Code
post_pred |> head()
# A tibble: 6 × 5
  Parameter      Chain Iteration value  time
  <chr>          <int>     <int> <dbl> <dbl>
1 p_fail_pred[1]     1         1     0     0
2 p_fail_pred[1]     1         2     0     0
3 p_fail_pred[1]     1         3     0     0
4 p_fail_pred[1]     1         4     0     0
5 p_fail_pred[1]     1         5     0     0
6 p_fail_pred[1]     1         6     0     0
Code
# Summarise posterior predictive uncertainty as the variance of the predicted
# failure probability at each prediction time. Defaults to the global post_pred.
estimate_uncertainty <- function(posterior = post_pred) {
  grouped <- group_by(posterior, time)
  summarise(grouped, uncertainty_base = var(value))
}

estimate_uncertainty() |> head()
# A tibble: 6 × 2
   time uncertainty_base
  <dbl>            <dbl>
1   0           0       
2   0.2         4.36e-12
3   0.4         1.32e-10
4   0.6         1.09e- 9
5   0.8         5.34e- 9
6   1           1.93e- 8
Code
post_pred.head()
shape: (5, 5)
Chain Iteration Parameter value time
i64 i64 str f64 f64
1 1 "p_fail_pred[1]" 0.0 0.0
1 2 "p_fail_pred[1]" 0.0 0.0
1 3 "p_fail_pred[1]" 0.0 0.0
1 4 "p_fail_pred[1]" 0.0 0.0
1 5 "p_fail_pred[1]" 0.0 0.0
Code
def estimate_uncertainty(posterior = post_pred):
    """Variance of the predicted failure probability at each prediction time.

    Defaults to the module-level ``post_pred`` (bound at definition time).
    Returns one row per ``time``, sorted ascending.
    """
    per_time = posterior.group_by("time").agg(
        pl.col("value").var().alias("uncertainty")
    )
    return per_time.sort("time")

estimate_uncertainty().head()
shape: (5, 2)
time uncertainty
f64 f64
0.0 0.0
0.2 4.3591e-12
0.4 1.3194e-10
0.6 1.0948e-9
0.8 5.3361e-9
Code
first(post_pred, 6)

"""
    estimate_uncertainty(posterior::DataFrame = post_pred)

Return the variance of the posterior predictive failure probability at each
prediction time: one row per `:pred_time`, variance in `:uncertainty`.
"""
function estimate_uncertainty(posterior::DataFrame = post_pred)
    # Expand the per-row vectors into one row per (draw, pred_time) pair,
    # then measure the spread across draws at each time.
    flat = flatten(posterior, :pr_fail_pred)
    grouped = groupby(flat, :pred_time)
    return combine(grouped, :pr_fail_pred => var => :uncertainty)
end
6×2 DataFrame
 Row │ pred_time  pr_fail_pred
     │ Float64    Array…
─────┼──────────────────────────────────────────────
   1 │       0.0  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0…
   2 │       0.2  [2.68764e-9, 4.52591e-9, 1.20061…
   3 │       0.4  [8.86347e-8, 1.51142e-7, 2.13236…
   4 │       0.6  [6.85036e-7, 1.17674e-6, 1.14747…
   5 │       0.8  [2.92305e-6, 5.04732e-6, 3.78708…
   6 │       1.0  [9.00733e-6, 1.56161e-5, 9.56126…

expected information gain

Code
# Expected information gain from inspecting one more component at proposed_time:
# re-fit the model under both possible outcomes (failure / no failure), measure
# the reduction in posterior predictive variance for each, and weight by the
# currently predicted probability of that outcome.
# Relies on globals: model_data, survival_model, post_pred, pred_params,
# estimate_uncertainty.
estimate_information_gain <- function(proposed_time) {
  # we need new datasets (hypothesising our next data point)
  # NOTE: `->` is R's right-assignment, so this single line copies model_data
  # into BOTH fail_data and no_fail_data.
  fail_data <- model_data -> no_fail_data
  
  # case A: we observe a failure
  fail_data$n_meas <- fail_data$n_meas + 1
  fail_data$obs_time <- c(fail_data$obs_time, proposed_time)
  # assumes the failure is localised to a 1.5-unit window before the
  # inspection -- TODO confirm this window against the inspection procedure
  fail_data$fail_lb <- c(fail_data$fail_lb, proposed_time - 1.5)
  fail_data$fail_ub <- c(fail_data$fail_ub, proposed_time)
  fail_data$fail_status <- c(fail_data$fail_status, 1)

  # case B: we do not observe a failure
  no_fail_data$n_meas <- no_fail_data$n_meas + 1
  no_fail_data$obs_time <- c(no_fail_data$obs_time, proposed_time)
  no_fail_data$fail_lb <- c(no_fail_data$fail_lb, proposed_time)
  no_fail_data$fail_ub <- c(no_fail_data$fail_ub, Inf)
  no_fail_data$fail_status <- c(no_fail_data$fail_status, 0)

  # re-fitting our models for each possible outcome
  fail_fit <- survival_model$sample(
    data = fail_data,
    chains = 4,
    parallel_chains = parallel::detectCores(),
    seed = 231123,
    iter_warmup = 2000,
    iter_sampling = 2000
  )

  no_fail_fit <- survival_model$sample(
    data = no_fail_data,
    chains = 4,
    parallel_chains = parallel::detectCores(),
    seed = 231123,
    iter_warmup = 2000,
    iter_sampling = 2000
  )

  # quantify uncertainty in the new predictions
  base_uncertainties <- estimate_uncertainty()
    
  fail_uncertainties <- fail_fit |>
    DomDF::tidy_mcmc_draws(params = pred_params) |>
    mutate(time = rep(x = model_data$pred_time, 
           each = fail_fit$metadata()$iter_sampling * length(fail_fit$metadata()$id))) |>
    estimate_uncertainty() |> rename(uncertainty_fail = uncertainty_base)
    
  no_fail_uncertainties <- no_fail_fit |>
    DomDF::tidy_mcmc_draws(params = pred_params) |>
    mutate(time = rep(x = model_data$pred_time, 
           each = no_fail_fit$metadata()$iter_sampling * length(no_fail_fit$metadata()$id))) |>
    estimate_uncertainty() |> rename(uncertainty_no_fail = uncertainty_base)
    
  # what are the prior probabilities of each outcome?
  # (mean predicted failure probability at the grid time nearest proposed_time)
  p_fail <- post_pred |>
    filter(abs(time - proposed_time) == min(abs(time - proposed_time))) |>
    summarise(p = mean(value)) |>
    pull(p)
    
  information_gains <- base_uncertainties |>
    left_join(fail_uncertainties, by = "time") |>
    left_join(no_fail_uncertainties, by = "time") |>
    mutate(
      # calculate a weighted uncertainty reduction
      # (pmax(0, .) ignores times where the hypothetical data would increase variance)
      weighted_reduction = pmax(0, (uncertainty_base - uncertainty_fail)) * p_fail +
                           pmax(0, (uncertainty_base - uncertainty_no_fail)) * (1 - p_fail)

    )
    
  # return the expected information gain
  return(information_gains$weighted_reduction |> sum())
}
Code
import copy

def estimate_information_gain(proposed_time):
  """Expected information gain from one more inspection at proposed_time.

  Re-fits the model under both hypothetical outcomes (failure observed /
  not observed), measures the reduction in posterior predictive variance
  within +/- `window` of the proposed time, and weights each reduction by
  the currently predicted probability of that outcome.

  Relies on globals: model_data, survival_model, post_pred, pred_params,
  large_num, and process_mcmc_draws (defined elsewhere in this document).
  """
  fail_data = copy.deepcopy(model_data)
  no_fail_data = copy.deepcopy(model_data)
  
  # Normalise every series to a plain list so .append() below works whether
  # the original entry was a numpy array or already a list.
  fail_data["obs_time"] = model_data["obs_time"].tolist() if hasattr(model_data["obs_time"], "tolist") else list(model_data["obs_time"])
  fail_data["fail_lb"]   = model_data["fail_lb"].tolist() if hasattr(model_data["fail_lb"], "tolist") else list(model_data["fail_lb"])
  fail_data["fail_ub"]   = model_data["fail_ub"].tolist() if hasattr(model_data["fail_ub"], "tolist") else list(model_data["fail_ub"])
  fail_data["fail_status"] = model_data["fail_status"].tolist() if hasattr(model_data["fail_status"], "tolist") else list(model_data["fail_status"])

  no_fail_data["obs_time"] = model_data["obs_time"].tolist() if hasattr(model_data["obs_time"], "tolist") else list(model_data["obs_time"])
  no_fail_data["fail_lb"]   = model_data["fail_lb"].tolist() if hasattr(model_data["fail_lb"], "tolist") else list(model_data["fail_lb"])
  no_fail_data["fail_ub"]   = model_data["fail_ub"].tolist() if hasattr(model_data["fail_ub"], "tolist") else list(model_data["fail_ub"])
  no_fail_data["fail_status"] = model_data["fail_status"].tolist() if hasattr(model_data["fail_status"], "tolist") else list(model_data["fail_status"])

  # Case A: a failure is observed, interval-censored in the 1.5-unit window
  # before the inspection -- assumption mirrored from the R version; confirm.
  fail_data["n_meas"] = model_data["n_meas"] + 1
  fail_data["obs_time"].append(proposed_time)
  fail_data["fail_lb"].append(proposed_time - 1.5)
  fail_data["fail_ub"].append(proposed_time)
  fail_data["fail_status"].append(1)

  # Case B: no failure observed (right-censored; large_num stands in for Inf).
  no_fail_data["n_meas"] = model_data["n_meas"] + 1
  no_fail_data["obs_time"].append(proposed_time)
  no_fail_data["fail_lb"].append(proposed_time)
  no_fail_data["fail_ub"].append(large_num)  
  no_fail_data["fail_status"].append(0)
    
  # Re-fit the model under each hypothetical outcome (same seed/settings
  # as the original fit).
  fail_fit = survival_model.sample(
    data = fail_data,
    chains = 4,
    parallel_chains = multiprocessing.cpu_count(),
    seed = 231123,
    iter_warmup = 2000,
    iter_sampling = 2000
  )
  
  no_fail_fit = survival_model.sample(
    data = no_fail_data,
    chains = 4,
    parallel_chains = multiprocessing.cpu_count(),
    seed = 231123,
    iter_warmup = 2000,
    iter_sampling = 2000
  )
    
  # Only compare uncertainty within +/- window of the proposed time
  # (the R version sums over the whole grid instead).
  window = 2.0
  
  base_uncertainties = (
    post_pred
    .filter(abs(pl.col("time") - proposed_time) <= window)
    .group_by("time")
    .agg(uncertainty_base=pl.col("value").var())
    .sort("time")
  )

  fail_post = (
    process_mcmc_draws(fail_fit, pred_params)
    .filter((pl.col("time") - proposed_time).abs() <= window)
    .group_by("time")
    .agg(pl.col("value").var().alias("uncertainty_fail"))
    .sort("time")
  )
  
  no_fail_post = (
    process_mcmc_draws(no_fail_fit, pred_params)
    .filter((pl.col("time") - proposed_time).abs() <= window)
    .group_by("time")
    .agg(pl.col("value").var().alias("uncertainty_no_fail"))
    .sort("time")
  )
    
  # Probability of observing a failure: mean predicted failure probability
  # at the grid time closest to proposed_time.
  min_diff = (
    post_pred
    .select((pl.col("time") - proposed_time).abs().alias("diff"))
    .select(pl.col("diff").min())
    .item()
  )
    
  p_fail = (
    post_pred
    .filter((pl.col("time") - proposed_time).abs() == min_diff)
    .select(pl.col("value").mean().alias("p"))
    .item()
  )
    
  # Probability-weighted variance reduction; negative reductions (times where
  # the hypothetical data would increase variance) are clamped to zero.
  information_gains = (
    base_uncertainties
    .join(fail_post, on="time", how="left")
    .join(no_fail_post, on="time", how="left")
    .with_columns(
        weighted_reduction=(
            pl.when(pl.col("uncertainty_base") - pl.col("uncertainty_fail") > 0)
              .then(pl.col("uncertainty_base") - pl.col("uncertainty_fail"))
              .otherwise(0) * p_fail +
            pl.when(pl.col("uncertainty_base") - pl.col("uncertainty_no_fail") > 0)
              .then(pl.col("uncertainty_base") - pl.col("uncertainty_no_fail"))
              .otherwise(0) * (1 - p_fail)
        )
    )
  )
    
  # Return the total information gain (sum over weighted_reduction)
  total_gain = information_gains.select(pl.col("weighted_reduction")).sum().item()
  return total_gain
Code
"""
    estimate_information_gain(proposed_time::Float64)

Expected information gain from inspecting one additional component at
`proposed_time`: re-fit the model under both hypothetical outcomes
(failure observed / not observed), measure the reduction in posterior
predictive variance for each, and weight by the currently predicted
probability of that outcome.

Relies on globals: `failure_data`, `pred_times`, `post_pred`,
`loglogistic_survival`, `estimate_uncertainty`.
"""
function estimate_information_gain(proposed_time::Float64)
    
    new_comp_id = maximum(failure_data.component_id) + 1
    
    # Two hypothetical datasets, each with one extra row. The failure case is
    # interval-censored in the 1.5-unit window before the inspection
    # (assumption mirrored from the R/Python versions -- confirm); the
    # no-failure case is right-censored (fail_ub = Inf).
    scenarios = (
        fail = (
            data = deepcopy(failure_data) |>
                    model_data -> push!(model_data, (component_id = new_comp_id, 
                            fail_lb = proposed_time - 1.5, 
                            fail_ub = proposed_time)),
            name = :uncertainty_fail
        ),
        no_fail = (
            data = deepcopy(failure_data) |>
                    model_data -> push!(model_data, (component_id = new_comp_id, 
                            fail_lb = proposed_time, 
                            fail_ub = Inf)),
            name = :uncertainty_no_fail
        )
    )
    
    # Re-fit the model on a scenario's dataset and return per-time posterior
    # predictive variance, renamed to the scenario's column name.
    function process_scenario(scenario)
        # existing units observed at t = 12; the hypothetical one at proposed_time
        observation_times = vcat(repeat([12.0], nrow(failure_data)), [proposed_time])
        
        loglogistic_survival(
            observation_times,
            scenario.data.fail_lb,
            scenario.data.fail_ub,
            isfinite.(scenario.data.fail_ub) |> x -> Int.(x)
        ) |>
        model -> sample(MersenneTwister(231123), model, NUTS(), MCMCThreads(), 2000, 4) |>
        DataFrame |>
        df -> @rselect(df, :iteration, :chain, :scale, :shape) |>
        df -> @rtransform(df, :pr_fail_pred = cdf.(LogLogisticDistribution(:scale, :shape), pred_times)) |>
        df -> df.pr_fail_pred |>
        preds -> [getindex.(preds, i) for i in 1:length(pred_times)] |>
        preds -> DataFrame(pred_time = pred_times, pr_fail_pred = preds) |>
        estimate_uncertainty |>
        df -> rename(df, :uncertainty => scenario.name)
    end
    
    fail_uncertainties = process_scenario(scenarios.fail)
    no_fail_uncertainties = process_scenario(scenarios.no_fail)
    base_uncertainties = estimate_uncertainty()
    
    # Probability of observing a failure: mean predicted failure probability
    # at the grid time closest to proposed_time.
    p_fail = post_pred |>
        df -> @rsubset(df, abs(:pred_time - proposed_time) == 
                      minimum(abs.(:pred_time .- proposed_time))) |>
        df -> df.pr_fail_pred |> first |> mean
    
    # Probability-weighted variance reduction, clamped at zero per time,
    # summed over the prediction grid.
    leftjoin(base_uncertainties, fail_uncertainties, on = :pred_time) |>
        df -> leftjoin(df, no_fail_uncertainties, on = :pred_time) |>
        df -> @rtransform(df, :weighted_reduction = 
              max(0, (:uncertainty - :uncertainty_fail)) * p_fail +
              max(0, (:uncertainty - :uncertainty_no_fail)) * (1 - p_fail)) |>
        df -> df.weighted_reduction |> sum
end

expected information gain

experimental design

  • what do we want to achieve with data collection?
    • reduce uncertainty in predictions?
    • test a hypothesis?
    • support decision-making? (see “value of information analysis”)

break?